{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f705f4be70e9"
      },
      "outputs": [],
      "source": [
        "# Copyright 2025 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4d1eedd3d0f1"
      },
      "source": [
        "# Semantic Router Agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "75662f324f65"
      },
      "source": [
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/semantic-router.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fagents%2Fgenai-experience-concierge%2Fagent-design-patterns%2Fsemantic-router.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/agents/genai-experience-concierge/agent-design-patterns/semantic-router.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/semantic-router.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/semantic-router.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/semantic-router.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/semantic-router.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/semantic-router.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/semantic-router.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3bf88396d108"
      },
      "source": [
        "| | | |\n",
        "|-|-|-|\n",
        "|Author(s) | [Ahmad Khan](https://github.com/Akhan221) | [Pablo Gaeta](https://github.com/pablofgaeta) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a77e180db2bb"
      },
      "source": [
        "## Overview\n",
        "\n",
        "### Introduction\n",
        "\n",
        "The semantic router pattern is an approach to dynamically pick one expert agent from a collection of candidates that is best fit to address the current user input.\n",
        "\n",
        "This demo uses an LLM-based intent detection classifier to route each user query to either a \"Retail Search\" or \"Customer Support\" expert assistant. The experts are mocked as simple Gemini calls with a system prompt for this demo, but represent an arbitrary actor that can share session history with all other sub-agents. For example, the customer support agent might be implemented with [Contact Center as a Service](https://cloud.google.com/solutions/contact-center-ai-platform) while the retail search assistant is built with Gemini and deployed on Cloud Run.\n",
        "\n",
        "The semantic router layer can provide a useful facade to enable a single interface for multiple different agent backends.\n",
        "\n",
        "### Key Components\n",
        "\n",
        "The key components of this Semantic Router Agent include:\n",
        "\n",
        "* **Language Model:** Gemini is used for intent detection and conversational response generation.\n",
        "* **State Management:** LangGraph manages the conversation flow and maintains the session state, including conversation history and sub-agent routing.\n",
        "* **Router Node:** This node is responsible for classifying user queries. It uses a Gemini model to analyze the input and determine whether it relates to retail search, customer service, or falls outside the agent's capabilities.\n",
        "\n",
        "### Workflow\n",
        "\n",
        "The workflow of the Semantic Router Agent can be summarized as follows:\n",
        "\n",
        "1.  The user provides an input query.\n",
        "2.  The **Router Node** classifies the query using a Gemini model, determining the appropriate sub-agent.\n",
        "3.  The query is routed to either the **Retail Search Node** or the **Customer Service Node**, depending on the classification. If the query is unsupported, it's routed to the **Post-Process Node** and a fallback message is provided.\n",
        "4.  The selected sub-agent (or the Post-Process Node for unsupported queries) generates a response.\n",
        "5.  The **Post-Process Node** updates the conversation history with the user input and the agent's response.\n",
        "6.  The agent is ready for the next user input, continuing the loop."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "93b9c58f24d9"
      },
      "source": [
        "## Get Started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d5027929de8f"
      },
      "source": [
        "### Install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "34a6da1ef6ac"
      },
      "outputs": [],
      "source": [
        "%pip install -q langgraph langgraph-checkpoint google-genai"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f42d12d15616"
      },
      "source": [
        "### Restart runtime\n",
        "\n",
        "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
        "\n",
        "The restart might take a minute or longer. After it's restarted, continue to the next step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "06fd78d27773"
      },
      "outputs": [],
      "source": [
        "# import IPython\n",
        "\n",
        "# app = IPython.Application.instance()\n",
        "# app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e114f5653870"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you're running this notebook on Google Colab, run the cell below to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "911453311a5d"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0724a3d2c4f9"
      },
      "source": [
        "## Notebook parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "c0af4bf2ab7b"
      },
      "outputs": [],
      "source": [
        "# Use the environment variable if the user doesn't provide Project ID.\n",
        "import os\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n",
        "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "\n",
        "REGION = \"us-central1\"  # @param {type:\"string\"}\n",
        "CHAT_MODEL_NAME = \"gemini-2.0-flash-001\"  # @param {type:\"string\"}\n",
        "ROUTER_MODEL_NAME = \"gemini-2.0-flash-001\"  # @param {type:\"string\"}\n",
        "MAX_TURN_HISTORY = 3  # @param {type:\"integer\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "59015bb54f4a"
      },
      "source": [
        "## Define the Semantic Router Agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6a04ecba1630"
      },
      "source": [
        "### Import dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "7144f0d22ddc"
      },
      "outputs": [],
      "source": [
        "from collections.abc import AsyncIterator\n",
        "import datetime\n",
        "import enum\n",
        "import logging\n",
        "from typing import Literal, TypedDict\n",
        "import uuid\n",
        "\n",
        "from IPython import display as ipd\n",
        "from google import genai\n",
        "from google.genai import types as genai_types\n",
        "from langchain_core.runnables import config as lc_config\n",
        "from langgraph import graph\n",
        "from langgraph import types as lg_types\n",
        "from langgraph.checkpoint import memory as memory_checkpoint\n",
        "from langgraph.config import get_stream_writer\n",
        "import pydantic\n",
        "\n",
        "logger = logging.getLogger(__name__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f6563a070bad"
      },
      "source": [
        "### Define schemas\n",
        "\n",
        "Defines all of the schemas, constants, and types required for building the agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "57ca599e4cd2"
      },
      "outputs": [],
      "source": [
        "# Agent config settings\n",
        "\n",
        "\n",
        "class AgentConfig(pydantic.BaseModel):\n",
        "    \"\"\"Configuration settings for the agent, including project, region, and model details.\"\"\"\n",
        "\n",
        "    project: str\n",
        "    \"\"\"The Google Cloud project ID.\"\"\"\n",
        "    region: str\n",
        "    \"\"\"The Google Cloud region where the agent is deployed.\"\"\"\n",
        "    chat_model_name: str\n",
        "    \"\"\"The name of the Gemini chat model to use for generating responses.\"\"\"\n",
        "    router_model_name: str\n",
        "    \"\"\"The name of the Gemini model to use for routing user queries.\"\"\"\n",
        "    max_router_turn_history: int\n",
        "    \"\"\"The maximum number of turns to include in the router's context window.\"\"\"\n",
        "\n",
        "\n",
        "# Node names and literal types\n",
        "\n",
        "ROUTER_NODE_NAME = \"ROUTER\"\n",
        "\"\"\"The name of the router node in the LangGraph.\"\"\"\n",
        "RouterNodeTargetLiteral = Literal[\"ROUTER\"]\n",
        "\"\"\"Literal type for the router node target.\"\"\"\n",
        "\n",
        "RETAIL_NODE_NAME = \"RETAIL\"\n",
        "\"\"\"The name of the retail node in the LangGraph.\"\"\"\n",
        "RetailNodeTargetLiteral = Literal[\"RETAIL\"]\n",
        "\"\"\"Literal type for the retail node target.\"\"\"\n",
        "\n",
        "CUSTOMER_SERVICE_NODE_NAME = \"CUSTOMER_SERVICE\"\n",
        "\"\"\"The name of the customer service node in the LangGraph.\"\"\"\n",
        "CustomerServiceNodeTargetLiteral = Literal[\"CUSTOMER_SERVICE\"]\n",
        "\"\"\"Literal type for the customer service node target.\"\"\"\n",
        "\n",
        "POST_PROCESS_NODE_NAME = \"POST_PROCESS\"\n",
        "\"\"\"The name of the post-processing node in the LangGraph.\"\"\"\n",
        "PostProcessNodeTargetLiteral = Literal[\"POST_PROCESS\"]\n",
        "\"\"\"Literal type for the post-processing node target.\"\"\"\n",
        "\n",
        "EndNodeTargetLiteral = Literal[\"__end__\"]\n",
        "\"\"\"Literal type for the end node target.\"\"\"\n",
        "\n",
        "# Router classification\n",
        "\n",
        "\n",
        "class RouterTarget(enum.Enum):\n",
        "    \"\"\"Enumeration representing the possible targets for routing user queries.\"\"\"\n",
        "\n",
        "    customer_service = \"Customer Support Assistant\"\n",
        "    \"\"\"Target for customer service related queries.\"\"\"\n",
        "    retail_search = \"Conversational Retail Search Assistant\"\n",
        "    \"\"\"Target for retail search related queries.\"\"\"\n",
        "    unsupported = \"Unsupported\"\n",
        "    \"\"\"Target for unsupported queries.\"\"\"\n",
        "\n",
        "\n",
        "class RouterClassification(pydantic.BaseModel):\n",
        "    \"\"\"Structured classification output for routing user queries.\"\"\"\n",
        "\n",
        "    reason: str = pydantic.Field(\n",
        "        description=\"Reason for classifying the latest user query.\"\n",
        "    )\n",
        "    \"\"\"Explanation of why the query was classified to a specific target.\"\"\"\n",
        "    target: RouterTarget\n",
        "    \"\"\"The target node to route the query to.\"\"\"\n",
        "\n",
        "    model_config = pydantic.ConfigDict(\n",
        "        json_schema_extra={\"propertyOrdering\": [\"reason\", \"target\"]}\n",
        "    )\n",
        "    \"\"\"Configuration to specify the ordering of properties in the JSON schema.\"\"\"\n",
        "\n",
        "\n",
        "# LangGraph models\n",
        "\n",
        "\n",
        "class Turn(TypedDict, total=False):\n",
        "    \"\"\"\n",
        "    Represents a single turn in a conversation.\n",
        "\n",
        "    Attributes:\n",
        "        id: Unique identifier for the turn.\n",
        "        created_at: Timestamp of when the turn was created.\n",
        "        user_input: The user's input in this turn.\n",
        "        response: The agent's response in this turn, if any.\n",
        "        router_classification: The router classification for this turn, if any.\n",
        "        messages: A list of Gemini content messages associated with this turn.\n",
        "    \"\"\"\n",
        "\n",
        "    id: uuid.UUID\n",
        "    \"\"\"Unique identifier for the turn.\"\"\"\n",
        "\n",
        "    created_at: datetime.datetime\n",
        "    \"\"\"Timestamp of when the turn was created.\"\"\"\n",
        "\n",
        "    user_input: str\n",
        "    \"\"\"The user's input for this turn.\"\"\"\n",
        "\n",
        "    response: str\n",
        "    \"\"\"The agent's response for this turn, if any.\"\"\"\n",
        "\n",
        "    router_classification: RouterClassification | None\n",
        "    \"\"\"The router classification for this turn, if any.\"\"\"\n",
        "\n",
        "    messages: list[genai_types.Content]\n",
        "    \"\"\"List of Gemini Content objects representing the conversation messages in this turn.\"\"\"\n",
        "\n",
        "\n",
        "class GraphSession(TypedDict, total=False):\n",
        "    \"\"\"\n",
        "    Represents the complete state of a conversation session.\n",
        "\n",
        "    Attributes:\n",
        "        id: Unique identifier for the session.\n",
        "        created_at: Timestamp of when the session was created.\n",
        "        current_turn: The current turn in the session, if any.\n",
        "        turns: A list of all turns in the session.\n",
        "    \"\"\"\n",
        "\n",
        "    id: uuid.UUID\n",
        "    \"\"\"Unique identifier for the session.\"\"\"\n",
        "\n",
        "    created_at: datetime.datetime\n",
        "    \"\"\"Timestamp of when the session was created.\"\"\"\n",
        "\n",
        "    current_turn: Turn | None\n",
        "    \"\"\"The current conversation turn.\"\"\"\n",
        "\n",
        "    turns: list[Turn]\n",
        "    \"\"\"List of all conversation turns in the session.\"\"\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "13cee5f003b8"
      },
      "source": [
        "### Nodes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f4f88084f38d"
      },
      "source": [
        "#### Router Node\n",
        "\n",
        "Classifies the current user input in context of the conversation and routes to the appropriate sub-agent. In case no sub-agent supports the user query, will respond with a fallback message."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "a724c4189b09"
      },
      "outputs": [],
      "source": [
        "UNSUPPORTED_FALLBACK_MESSAGE = \"I'm sorry, I am unable to process your request as it is outside of my current capabilities. Please try asking me about our retail business or customer support.\"\n",
        "\n",
        "ROUTER_SYSTEM_PROMPT = f\"\"\"\n",
        "You are an expert in classifying user queries for an agentic workflow for Cymbal, a retail company.\n",
        "First reason through how you will classify the query given the conversation history.\n",
        "Then, classify user queries to be sent to one of several AI assistants that can help the user.\n",
        "\n",
        "Classify every inputted query as: \"{RouterTarget.CUSTOMER_SERVICE.value}\", \"{RouterTarget.RETAIL_SEARCH.value}\", \"{RouterTarget.UNSUPPORTED.value}\".\n",
        "\n",
        "Class target descriptions:\n",
        "- \"{RouterTarget.RETAIL_SEARCH.value}\": Any pleasantries/general conversation or discussion of Cymbal retail products/stores/inventory, including live data.\n",
        "- \"{RouterTarget.CUSTOMER_SERVICE.value}\": Queries related to customer service such as item returns, policies, complaints, FAQs, escalations, etc.\n",
        "- \"{RouterTarget.UNSUPPORTED.value}\": Any query that is off topic or out of scope for one of the other agents.\n",
        "\n",
        "<examples>\n",
        "input: Is the Meinl Byzance Jazz Ride 18\" available?\n",
        "output: {RouterTarget.RETAIL_SEARCH.value}\n",
        "\n",
        "input: Recommend a good pair of running shoes.\n",
        "output: {RouterTarget.RETAIL_SEARCH.value}\n",
        "\n",
        "input: How do i initiate a return?\n",
        "output: {RouterTarget.CUSTOMER_SERVICE.value}\n",
        "\n",
        "input: you suck, why do you refuse to be useful!\n",
        "output: {RouterTarget.CUSTOMER_SERVICE.value}\n",
        "\n",
        "input: How far is the earth from the sun?\n",
        "output: {RouterTarget.UNSUPPORTED.value}\n",
        "\n",
        "input: What's the weather like today?\n",
        "output: {RouterTarget.UNSUPPORTED.value}\n",
        "</examples>\n",
        "\"\"\".strip()\n",
        "\n",
        "\n",
        "async def ainvoke_router(\n",
        "    state: GraphSession,\n",
        "    config: lc_config.RunnableConfig,\n",
        ") -> lg_types.Command[\n",
        "    Literal[\n",
        "        RetailNodeTargetLiteral,\n",
        "        CustomerServiceNodeTargetLiteral,\n",
        "        PostProcessNodeTargetLiteral,\n",
        "    ]\n",
        "]:\n",
        "    \"\"\"\n",
        "    Asynchronously invokes the router node to classify user input and determine the next action.\n",
        "\n",
        "    This function takes the current conversation state and configuration, interacts with the\n",
        "    Gemini model to classify the user's input based on predefined categories, and\n",
        "    determines which sub-agent should handle the request.\n",
        "\n",
        "    Args:\n",
        "        state: The current state of the conversation session, including user input and history.\n",
        "        config: The LangChain RunnableConfig containing agent-specific configurations.\n",
        "\n",
        "    Returns:\n",
        "        A Command object that specifies the next node to transition to (retail, customer service, or post-processing)\n",
        "        and the updated conversation state. This state includes the router classification.\n",
        "    \"\"\"\n",
        "\n",
        "    agent_config = AgentConfig.model_validate(\n",
        "        config[\"configurable\"].get(\"agent_config\", {})\n",
        "    )\n",
        "\n",
        "    stream_writer = get_stream_writer()\n",
        "\n",
        "    current_turn = state.get(\"current_turn\")\n",
        "    assert current_turn is not None, \"current turn must be set\"\n",
        "\n",
        "    user_input = current_turn.get(\"user_input\")\n",
        "    assert user_input is not None, \"user input must be set\"\n",
        "\n",
        "    # Initialize generate model\n",
        "    client = genai.Client(\n",
        "        vertexai=True,\n",
        "        project=agent_config.project,\n",
        "        location=agent_config.region,\n",
        "    )\n",
        "\n",
        "    # Add new user input to history\n",
        "    turns = state.get(\"turns\", [])[: agent_config.max_router_turn_history]\n",
        "    history = [content for turn in turns for content in turn.get(\"messages\", [])]\n",
        "    user_content = genai_types.Content(\n",
        "        role=\"user\",\n",
        "        parts=[genai_types.Part.from_text(text=user_input)],\n",
        "    )\n",
        "    contents = history + [user_content]\n",
        "\n",
        "    # generate streaming response\n",
        "    response = await client.aio.models.generate_content(\n",
        "        model=agent_config.router_model_name,\n",
        "        contents=contents,\n",
        "        config=genai_types.GenerateContentConfig(\n",
        "            candidate_count=1,\n",
        "            temperature=0.2,\n",
        "            seed=0,\n",
        "            system_instruction=ROUTER_SYSTEM_PROMPT,\n",
        "            response_mime_type=\"application/json\",\n",
        "            response_schema=RouterClassification,\n",
        "        ),\n",
        "    )\n",
        "\n",
        "    router_classification = RouterClassification.model_validate_json(response.text)\n",
        "\n",
        "    stream_writer(\n",
        "        {\n",
        "            \"router_classification\": {\n",
        "                \"target\": router_classification.target.value,\n",
        "                \"reason\": router_classification.reason,\n",
        "            }\n",
        "        }\n",
        "    )\n",
        "\n",
        "    current_turn[\"router_classification\"] = router_classification\n",
        "\n",
        "    next_node = None\n",
        "    match router_classification.target:\n",
        "        case RouterTarget.RETAIL_SEARCH:\n",
        "            next_node = RETAIL_NODE_NAME\n",
        "        case RouterTarget.CUSTOMER_SERVICE:\n",
        "            next_node = CUSTOMER_SERVICE_NODE_NAME\n",
        "        case RouterTarget.UNSUPPORTED:\n",
        "            next_node = POST_PROCESS_NODE_NAME\n",
        "            current_turn[\"response\"] = UNSUPPORTED_FALLBACK_MESSAGE\n",
        "            stream_writer({\"text\": current_turn[\"response\"]})\n",
        "        case _:  # never\n",
        "            raise RuntimeError(\n",
        "                f\"Unhandled router classification target: {router_classification.target}\"\n",
        "            )\n",
        "\n",
        "    return lg_types.Command(\n",
        "        update=GraphSession(current_turn=current_turn),\n",
        "        goto=next_node,\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f7009223ca3c"
      },
      "source": [
        "### Customer Service Node\n",
        "\n",
        "A mock sub-agent that will make up responses acting as the Cymbal Retail customer service agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "c6a8434cf8e6"
      },
      "outputs": [],
      "source": [
        "CUSTOMER_SERVICE_SYSTEM_PROMPT = \"Answer customer service questions about the Cymbal retail company. Cymbal offers both online retail and physical stores. Feel free to make up information about this fictional company, this is just for the purposes of a demo.\"\n",
        "\n",
        "\n",
        "async def ainvoke_customer_service(\n",
        "    state: GraphSession,\n",
        "    config: lc_config.RunnableConfig,\n",
        ") -> lg_types.Command[PostProcessNodeTargetLiteral]:\n",
        "    \"\"\"\n",
        "    Asynchronously invokes the customer service chat node to generate a response using a Gemini model.\n",
        "\n",
        "    This function takes the current conversation state and configuration, interacts with the\n",
        "    Gemini model to generate a customer service-oriented response based on the user's input\n",
        "    and conversation history, and streams the response back to the user.\n",
        "\n",
        "    Args:\n",
        "        state: The current state of the conversation session, including user input and history.\n",
        "        config: The LangChain RunnableConfig containing agent-specific configurations.\n",
        "\n",
        "    Returns:\n",
        "        A Command object that specifies the next node to transition to (post-processing) and the\n",
        "        updated conversation state. This state includes the model's customer service response and\n",
        "        the updated conversation history.\n",
        "    \"\"\"\n",
        "\n",
        "    agent_config = AgentConfig.model_validate(\n",
        "        config[\"configurable\"].get(\"agent_config\", {})\n",
        "    )\n",
        "\n",
        "    stream_writer = get_stream_writer()\n",
        "\n",
        "    current_turn = state.get(\"current_turn\")\n",
        "    assert current_turn is not None, \"current turn must be set\"\n",
        "\n",
        "    user_input = current_turn.get(\"user_input\")\n",
        "    assert user_input is not None, \"user input must be set\"\n",
        "\n",
        "    # Initialize generate model\n",
        "    client = genai.Client(\n",
        "        vertexai=True,\n",
        "        project=agent_config.project,\n",
        "        location=agent_config.region,\n",
        "    )\n",
        "\n",
        "    # Add new user input to history\n",
        "    turns = state.get(\"turns\", [])\n",
        "    history = [content for turn in turns for content in turn.get(\"messages\", [])]\n",
        "    user_content = genai_types.Content(\n",
        "        role=\"user\",\n",
        "        parts=[genai_types.Part.from_text(text=user_input)],\n",
        "    )\n",
        "    contents = history + [user_content]\n",
        "\n",
        "    try:\n",
        "        # generate streaming response\n",
        "        response: AsyncIterator[genai_types.GenerateContentResponse] = (\n",
        "            await client.aio.models.generate_content_stream(\n",
        "                model=agent_config.chat_model_name,\n",
        "                contents=contents,\n",
        "                config=genai_types.GenerateContentConfig(\n",
        "                    candidate_count=1,\n",
        "                    temperature=0.2,\n",
        "                    seed=0,\n",
        "                    system_instruction=CUSTOMER_SERVICE_SYSTEM_PROMPT,\n",
        "                ),\n",
        "            )\n",
        "        )\n",
        "\n",
        "        # stream response text to custom stream writer\n",
        "        response_text = \"\"\n",
        "        async for chunk in response:\n",
        "            response_text += chunk.text\n",
        "            stream_writer({\"text\": chunk.text})\n",
        "\n",
        "        response_content = genai_types.Content(\n",
        "            role=\"model\",\n",
        "            parts=[genai_types.Part.from_text(text=response_text)],\n",
        "        )\n",
        "\n",
        "    except Exception as e:\n",
        "        logger.exception(e)\n",
        "        # unexpected error, display it\n",
        "        response_text = f\"An unexpected error occurred during generation, please try again.\\n\\nError = {str(e)}\"\n",
        "        stream_writer({\"error\": response_text})\n",
        "        response_content = genai_types.Content(\n",
        "            role=\"model\",\n",
        "            parts=[genai_types.Part.from_text(text=response_text)],\n",
        "        )\n",
        "\n",
        "    current_turn[\"response\"] = response_text.strip()\n",
        "    current_turn[\"messages\"] = [user_content, response_content]\n",
        "\n",
        "    return lg_types.Command(\n",
        "        update=GraphSession(current_turn=current_turn),\n",
        "        goto=POST_PROCESS_NODE_NAME,\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1937b6a5fbf7"
      },
      "source": [
        "### Retail Search Node\n",
        "\n",
        "A mock sub-agent that plays the Cymbal Retail search assistant and generates made-up answers for demo purposes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "be8156bf07ac"
      },
      "outputs": [],
      "source": [
        "RETAIL_SYSTEM_PROMPT = \"Have a conversation with the user and answer questions about the Cymbal retail company. Cymbal offers both online retail and physical stores. Feel free to make up information about this fictional company, this is just for the purposes of a demo.\"\n",
        "\n",
        "\n",
        "async def ainvoke_retail_search(\n",
        "    state: GraphSession,\n",
        "    config: lc_config.RunnableConfig,\n",
        ") -> lg_types.Command[PostProcessNodeTargetLiteral]:\n",
        "    \"\"\"\n",
        "    Asynchronously invokes the retail search node to generate a response using a Gemini model.\n",
        "\n",
        "    This function takes the current conversation state and configuration, sends the user's\n",
        "    input and conversation history to the Gemini model with the retail system prompt,\n",
        "    and streams the response back to the user.\n",
        "\n",
        "    Args:\n",
        "        state: The current state of the conversation session, including user input and history.\n",
        "        config: The LangChain RunnableConfig containing agent-specific configurations.\n",
        "\n",
        "    Returns:\n",
        "        A Command object that specifies the next node to transition to (post-processing) and the\n",
        "        updated conversation state. This state includes the model's response and the updated\n",
        "        conversation history.\n",
        "    \"\"\"\n",
        "\n",
        "    agent_config = AgentConfig.model_validate(\n",
        "        config[\"configurable\"].get(\"agent_config\", {})\n",
        "    )\n",
        "\n",
        "    stream_writer = get_stream_writer()\n",
        "\n",
        "    current_turn = state.get(\"current_turn\")\n",
        "    assert current_turn is not None, \"current turn must be set\"\n",
        "\n",
        "    user_input = current_turn.get(\"user_input\")\n",
        "    assert user_input is not None, \"user input must be set\"\n",
        "\n",
        "    # Initialize the Gen AI client (Vertex AI backend)\n",
        "    client = genai.Client(\n",
        "        vertexai=True,\n",
        "        project=agent_config.project,\n",
        "        location=agent_config.region,\n",
        "    )\n",
        "\n",
        "    # Add new user input to history\n",
        "    turns = state.get(\"turns\", [])\n",
        "    history = [content for turn in turns for content in turn.get(\"messages\", [])]\n",
        "    user_content = genai_types.Content(\n",
        "        role=\"user\",\n",
        "        parts=[genai_types.Part.from_text(text=user_input)],\n",
        "    )\n",
        "    contents = history + [user_content]\n",
        "\n",
        "    try:\n",
        "        # generate streaming response\n",
        "        response: AsyncIterator[genai_types.GenerateContentResponse] = (\n",
        "            await client.aio.models.generate_content_stream(\n",
        "                model=agent_config.chat_model_name,\n",
        "                contents=contents,\n",
        "                config=genai_types.GenerateContentConfig(\n",
        "                    candidate_count=1,\n",
        "                    temperature=0.2,\n",
        "                    seed=0,\n",
        "                    system_instruction=RETAIL_SYSTEM_PROMPT,\n",
        "                ),\n",
        "            )\n",
        "        )\n",
        "\n",
        "        # stream response text to custom stream writer\n",
        "        response_text = \"\"\n",
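        "        async for chunk in response:\n",
        "            # chunk.text can be None (e.g. a chunk without text parts); skip those\n",
        "            if not chunk.text:\n",
        "                continue\n",
        "            response_text += chunk.text\n",
        "            stream_writer({\"text\": chunk.text})\n",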
        "\n",
        "        response_content = genai_types.Content(\n",
        "            role=\"model\",\n",
        "            parts=[genai_types.Part.from_text(text=response_text)],\n",
        "        )\n",
        "\n",
        "    except Exception as e:\n",
        "        logger.exception(e)\n",
        "        # unexpected error, display it\n",
        "        response_text = f\"An unexpected error occurred during generation, please try again.\\n\\nError = {str(e)}\"\n",
        "        stream_writer({\"error\": response_text})\n",
        "        response_content = genai_types.Content(\n",
        "            role=\"model\",\n",
        "            parts=[genai_types.Part.from_text(text=response_text)],\n",
        "        )\n",
        "\n",
        "    current_turn[\"response\"] = response_text.strip()\n",
        "    current_turn[\"messages\"] = [user_content, response_content]\n",
        "\n",
        "    return lg_types.Command(\n",
        "        update=GraphSession(current_turn=current_turn),\n",
        "        goto=POST_PROCESS_NODE_NAME,\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1a332b28e263"
      },
      "source": [
        "### Post-Process Node\n",
        "\n",
        "Append the current turn to the conversation history, then reset the current turn."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "1173d8589370"
      },
      "outputs": [],
      "source": [
        "async def ainvoke_post_process(\n",
        "    state: GraphSession,\n",
        "    config: lc_config.RunnableConfig,\n",
        ") -> lg_types.Command[EndNodeTargetLiteral]:\n",
        "    \"\"\"\n",
        "    Asynchronously invokes the post-processing node to finalize the current conversation turn.\n",
        "\n",
        "    This function takes the current conversation state, validates that the current turn and its response are set,\n",
        "    adds the completed turn to the conversation history, and resets the current turn. This effectively concludes\n",
        "    the processing of the current user input and prepares the session for the next input.\n",
        "\n",
        "    Args:\n",
        "        state: The current state of the conversation session.\n",
        "        config: The LangChain RunnableConfig (unused in this function).\n",
        "\n",
        "    Returns:\n",
        "        A Command object specifying the end of the graph execution and the updated conversation state.\n",
        "    \"\"\"\n",
        "\n",
        "    del config  # unused\n",
        "\n",
        "    current_turn = state.get(\"current_turn\")\n",
        "\n",
        "    assert current_turn is not None, \"Current turn must be set.\"\n",
        "    assert (\n",
        "        current_turn.get(\"response\") is not None\n",
        "    ), \"Response from current turn must be set.\"\n",
        "\n",
        "    turns = state.get(\"turns\", []) + [current_turn]\n",
        "\n",
        "    return lg_types.Command(update=GraphSession(current_turn=None, turns=turns))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1ba5eecf61c0"
      },
      "source": [
        "## Compile Semantic Router Agent"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "113087763341"
      },
      "outputs": [],
      "source": [
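        "def load_graph() -> graph.StateGraph:\n",
        "    \"\"\"Build the uncompiled semantic router state graph.\n",
        "\n",
        "    No explicit edges are added here; transitions are driven by the\n",
        "    Command objects that the nodes return.\n",
        "    \"\"\"\n",
        "    # Graph\n",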
        "    state_graph = graph.StateGraph(state_schema=GraphSession)\n",
        "\n",
        "    # Nodes\n",
        "    state_graph.add_node(ROUTER_NODE_NAME, ainvoke_router)\n",
        "    state_graph.add_node(RETAIL_NODE_NAME, ainvoke_retail_search)\n",
        "    state_graph.add_node(CUSTOMER_SERVICE_NODE_NAME, ainvoke_customer_service)\n",
        "    state_graph.add_node(POST_PROCESS_NODE_NAME, ainvoke_post_process)\n",
        "    state_graph.set_entry_point(ROUTER_NODE_NAME)\n",
        "\n",
        "    return state_graph\n",
        "\n",
        "\n",
        "state_graph = load_graph()\n",
        "compiled_graph = state_graph.compile(memory_checkpoint.MemorySaver())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "012da4f35fcb"
      },
      "source": [
        "### Visualize agent graph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "3fb9d5378f45"
      },
      "outputs": [],
      "source": [
        "display(ipd.Image(compiled_graph.get_graph().draw_mermaid_png()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f207d60fb4eb"
      },
      "source": [
        "### Wrapper function to stream generation output to notebook"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "8af554de8e8e"
      },
      "outputs": [],
      "source": [
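        "async def ask(user_input: str, session: str | None = None) -> None:\n",
        "    \"\"\"Run one agent turn and stream its output to the notebook.\n",
        "\n",
        "    Args:\n",
        "        user_input: The user's message for this turn.\n",
        "        session: Thread ID used to keep conversation history across calls;\n",
        "            a new random ID is generated if omitted.\n",
        "    \"\"\"\n",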
        "    thread_id = session or uuid.uuid4().hex\n",
        "\n",
        "    agent_config = AgentConfig(\n",
        "        project=PROJECT_ID,\n",
        "        region=REGION,\n",
        "        chat_model_name=CHAT_MODEL_NAME,\n",
        "        router_model_name=ROUTER_MODEL_NAME,\n",
        "        max_router_turn_history=MAX_TURN_HISTORY,\n",
        "    )\n",
        "\n",
        "    current_source = last_source = None\n",
        "    all_text = \"\"\n",
        "    async for stream_mode, chunk in compiled_graph.astream(\n",
        "        input={\"current_turn\": {\"user_input\": user_input}},\n",
        "        config={\"configurable\": {\"thread_id\": thread_id, \"agent_config\": agent_config}},\n",
        "        stream_mode=[\"custom\"],\n",
        "    ):\n",
        "        assert isinstance(chunk, dict), \"Expected dictionary chunk\"\n",
        "\n",
        "        text = \"\"\n",
        "\n",
        "        if \"router_classification\" in chunk:\n",
        "            target = chunk[\"router_classification\"][\"target\"]\n",
        "            reason = chunk[\"router_classification\"][\"reason\"]\n",
        "\n",
        "            text = f\"Agent Classification: {target}\\n\\nReason: {reason}\"\n",
        "            current_source = \"router_classification\"\n",
        "\n",
        "        elif \"text\" in chunk:\n",
        "            text = chunk[\"text\"]\n",
        "            current_source = \"text\"\n",
        "\n",
        "        elif \"error\" in chunk:\n",
        "            text = chunk[\"error\"]\n",
        "            current_source = \"error\"\n",
        "\n",
        "        else:\n",
        "            print(\"unhandled chunk case:\", chunk)\n",
        "\n",
        "        if last_source is not None and last_source != current_source:\n",
        "            text = \"\\n\\n---\\n\\n\" + text\n",
        "\n",
        "        last_source = current_source\n",
        "\n",
        "        all_text += text\n",
        "        display(ipd.Markdown(all_text), clear=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "795c759b53cd"
      },
      "source": [
        "## Test Conversation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "0ce37bbb0ea7"
      },
      "outputs": [],
      "source": [
        "session = uuid.uuid4().hex"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "ab8bc629aa73"
      },
      "outputs": [],
      "source": [
        "await ask(\"What products do you offer?\", session=session)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "ec23f8ccb029"
      },
      "outputs": [],
      "source": [
        "await ask(\"can you summarize the plot of twilight?\", session=session)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "1216555e8dd6"
      },
      "outputs": [],
      "source": [
        "await ask(\"WTF you can't do anything, let me speak to a HUMAN!!!\", session=session)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "4a6d3e47290b"
      },
      "outputs": [],
      "source": [
        "await ask(\n",
        "    \"i just want to return this twilight dvd and you're being stupid\", session=session\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "semantic-router.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
